#include "SteeringArrive.h"

#include <cassert>

#include "ai/steering/SteeringOutput.h"
#include "math/Point3.h"
#include "math/Utils.h"
#include "physic/PhysicConstants.h"

// Builds an ArriveData from raw pointers and scalar arrive parameters.
// The pointers are stored as given (they may be null); nothing is copied
// from the pointed-to objects. In builds where assert() is active, each
// scalar must be positive or compare equal to zero through areEquals().
ArriveData::ArriveData(const Point3* srcPos /*= nullptr*/,
                       const Vector3* srcLinearVel /*= nullptr*/,
                       const Point3* targetPos /*= nullptr*/,
                       const float targetRadius /*= 0.0f*/,
                       const float slowDownRadius /*= 0.0f*/,
                       const float timeToTarget /*= 0.0f*/)
    : mSrcPos(srcPos),
      mSrcLinearVel(srcLinearVel),
      mTargetPos(targetPos),
      mTargetRadius(targetRadius),
      mSlowDownRadius(slowDownRadius),
      mTimeToTarget(timeToTarget)
{
    // Validate the incoming scalars: negative values are rejected.
    assert(0.0f < targetRadius || areEquals(targetRadius, 0.0f));
    assert(0.0f < slowDownRadius || areEquals(slowDownRadius, 0.0f));
    assert(0.0f < timeToTarget || areEquals(timeToTarget, 0.0f));
}

// Overwrites every field of this ArriveData.
// Only the addresses of srcPos, srcLinearVel and targetPos are stored, so
// the referenced objects must stay alive for as long as the stored pointers
// are used. In builds where assert() is active, each scalar must be positive
// or compare equal to zero through areEquals().
void ArriveData::set(const Point3& srcPos,
                     const Vector3& srcLinearVel,
                     const Point3& targetPos,
                     const float targetRadius,
                     const float slowDownRadius,
                     const float timeToTarget)
{
    // Reject negative scalars before touching any member.
    assert(0.0f < targetRadius || areEquals(targetRadius, 0.0f));
    assert(0.0f < slowDownRadius || areEquals(slowDownRadius, 0.0f));
    assert(0.0f < timeToTarget || areEquals(timeToTarget, 0.0f));

    // Scalar parameters are copied by value.
    mTimeToTarget = timeToTarget;
    mSlowDownRadius = slowDownRadius;
    mTargetRadius = targetRadius;

    // Positions and velocity are referenced, not copied.
    mTargetPos = &targetPos;
    mSrcLinearVel = &srcLinearVel;
    mSrcPos = &srcPos;
}

void arrive(const ArriveData * const data, 
            SteeringOutput * const outputs,
            const uint32_t numData)
{
    assert(data);
    assert(outputs);
    assert(numData > 0);

    // PARALLEL_FOR
    for (size_t i = 0; i < numData; ++i) {
        const Point3& srcPos = *data[i].mSrcPos;
        const Vector3& srcLinearVel = *data[i].mSrcLinearVel;
        const Point3& targetPos = *data[i].mTargetPos;
        const float targetRadius = data[i].mTargetRadius;
        const float slowDownRadius = data[i].mSlowDownRadius;
        const float timeToTarget = data[i].mTimeToTarget;
        SteeringOutput& output = outputs[i];

        assert(targetRadius > 0.0f || areEquals(targetRadius, 0.0f));
        assert(slowDownRadius > 0.0f || areEquals(slowDownRadius, 0.0f));
        assert(timeToTarget > 0.0f || areEquals(timeToTarget, 0.0f));

        // Get the direction from source to target
        // and its distance
        Vector3 toTarget;
        toTarget.set(srcPos, targetPos);
        const float distance =  toTarget.length();

        // If we are outside the slow down radius, 
        // then we will try to advance at full speed.
        // Otherwise calculate a scaled speed.
        const float targetSpeed = 
            (slowDownRadius < distance) ? MAX_SPEED
            : MAX_SPEED * distance * (1.0f / slowDownRadius);

        // The target velocity combines target speed and direction.
        Vector3 targetVelocity = toTarget;
        normalize(targetVelocity, targetVelocity);
        targetVelocity *= targetSpeed;

        // Steering linear velocity tries to get
        // to the target velocity
        output.mLinearVel = targetVelocity;
        output.mLinearVel -= srcLinearVel;
        output.mLinearVel *= 1.0f;
        // Clamp linear velocity to max acceleration if necessary
        {
            const float sqrLen = output.mLinearVel.sqrLength();
            if (MAX_SPEED * MAX_SPEED < sqrLen) {
                output.mLinearVel.setLength(MAX_SPEED);
            }
        }
    }
}